// Copyright (c) Orbbec Inc. All Rights Reserved.
// Licensed under the MIT License.

#include "PointCloudSaveUtil.hpp"
#include "logger/Logger.hpp"
#include "logger/LoggerInterval.hpp"

#include <cmath>

namespace libobsensor {

/// A pair of unsigned pixel coordinates usable as an ordered (std::map) key.
/// Ordering compares `y` first and falls back to `x` only when the `y` values tie.
struct PixelCoord {
    unsigned int x;
    unsigned int y;

    PixelCoord(unsigned int xValue, unsigned int yValue) : x(xValue), y(yValue) {}

    bool operator==(const PixelCoord &rhs) const {
        return x == rhs.x && y == rhs.y;
    }

    // Strict weak ordering: y is the primary key, x the secondary key.
    bool operator<(const PixelCoord &rhs) const {
        return (y != rhs.y) ? (y < rhs.y) : (x < rhs.x);
    }
};

/// One mesh face described by three vertex indices; every index defaults to 0.
struct Triangle {
    unsigned int x;
    unsigned int y;
    unsigned int z;

    Triangle(unsigned int first = 0, unsigned int second = 0, unsigned int third = 0) : x(first), y(second), z(third) {}
};

/// A point-cloud vertex: a 3D position plus an optional 8-bit-per-channel color.
struct Vertex {
    float   x, y, z;
    uint8_t color[3];  // red, green, blue; zero-initialized so uncolored vertices never carry indeterminate bytes
    Vertex(float _x = 0.0f, float _y = 0.0f, float _z = 0.0f) : x(_x), y(_y), z(_z), color{0, 0, 0} {}
};

/// Geometry handed to the PLY writer.
struct MeshData {
    // Points to write, in output order.
    std::vector<Vertex> vertices;
    // Faces whose x/y/z members are indices into `vertices`; only written when meshing is enabled.
    std::vector<Triangle> faces;
};

// Points whose |z| is below this value are treated as empty and skipped.
static constexpr float minPointValue = 1e-6f;

inline float computePointsDistance(const Vertex &v1, const Vertex &v2) {
    float diffX = v1.x - v2.x;
    float diffY = v1.y - v2.y;
    float diffZ = v1.z - v2.z;
    return std::sqrt(diffX * diffX + diffY * diffY + diffZ * diffZ);
}

void savePointCloud(const char *fileName, MeshData meshData, bool useMesh, bool colorPointCloud, bool saveBinary) {
    std::ofstream plyOut(fileName);
    if(!plyOut) {
        std::cerr << "Error: Cannot open file " << fileName << " for writing!" << std::endl;
        return;
    }

    plyOut << "ply\n";
    if(saveBinary) {
        plyOut << "format binary_little_endian 1.0\n";
    }
    else {
        plyOut << "format ascii 1.0\n";
    }

    plyOut << "comment Generated by mesh generation code\n";
    plyOut << "element vertex " << meshData.vertices.size() << "\n";
    plyOut << "property float x\n";
    plyOut << "property float y\n";
    plyOut << "property float z\n";
    if(colorPointCloud) {
        plyOut << "property uchar red\n";
        plyOut << "property uchar green\n";
        plyOut << "property uchar blue\n";
    }

    if(useMesh) {
        plyOut << "element face " << meshData.faces.size() << "\n";
        plyOut << "property list uchar int vertex_indices\n";
    }

    plyOut << "end_header\n";

    if(saveBinary) {
        plyOut.close();
        plyOut.open(fileName, std::ios_base::app | std::ios_base::binary);
        for(size_t i = 0; i < meshData.vertices.size(); i++) {
            // write vertices
            plyOut.write(reinterpret_cast<const char *>(&(meshData.vertices[i].x)), sizeof(float));
            plyOut.write(reinterpret_cast<const char *>(&(meshData.vertices[i].y)), sizeof(float));
            plyOut.write(reinterpret_cast<const char *>(&(meshData.vertices[i].z)), sizeof(float));

            if(colorPointCloud) {
                plyOut.write(reinterpret_cast<const char *>(&meshData.vertices[i].color[0]), sizeof(uint8_t));
                plyOut.write(reinterpret_cast<const char *>(&meshData.vertices[i].color[1]), sizeof(uint8_t));
                plyOut.write(reinterpret_cast<const char *>(&meshData.vertices[i].color[2]), sizeof(uint8_t));
            }
        }

        // write faces
        if(useMesh) {
            for(size_t i = 0; i < meshData.faces.size(); i++) {
                static const int three = 3;
                plyOut.write(reinterpret_cast<const char *>(&three), sizeof(uint8_t));
                plyOut.write(reinterpret_cast<const char *>(&meshData.faces[i].x), sizeof(int));
                plyOut.write(reinterpret_cast<const char *>(&meshData.faces[i].y), sizeof(int));
                plyOut.write(reinterpret_cast<const char *>(&meshData.faces[i].z), sizeof(int));
            }
        }
    }
    else {
        // write vertices
        for(size_t i = 0; i < meshData.vertices.size(); i++) {
            plyOut << meshData.vertices[i].x << " " << meshData.vertices[i].y << " " << meshData.vertices[i].z << "\n";
            if(colorPointCloud) {
                plyOut << static_cast<int>(meshData.vertices[i].color[0]) << " " << static_cast<int>(meshData.vertices[i].color[1]) << " "
                       << static_cast<int>(meshData.vertices[i].color[2]) << "\n";
            }
        }

        // write faces
        if(useMesh) {
            for(size_t i = 0; i < meshData.faces.size(); i++) {
                plyOut << "3 " << meshData.faces[i].x << " " << meshData.faces[i].y << " " << meshData.faces[i].z << "\n";
            }
        }
    }

    plyOut.close();
}
bool PointCloudSaveUtil::savePointCloudToPly(const char *fileName, std::shared_ptr<Frame> frame, bool saveBinary, bool useMesh, float meshThreshold) {
    if(!frame) {
        LOG_WARN("depth point cloud frame is null");
        return false;
    }

    auto pointCloudFrame = frame->as<libobsensor::PointsFrame>();
    auto pointCloudType  = pointCloudFrame->getFormat();

    if(pointCloudType != OB_FORMAT_POINT && pointCloudType != OB_FORMAT_RGB_POINT) {
        LOG_WARN("point cloud format invalid");
        return false;
    }

    MeshData meshData;
    uint32_t width  = pointCloudFrame->getWidth();
    uint32_t height = pointCloudFrame->getHeight();

    std::map<PixelCoord, unsigned int> vertexIndexMap;
    std::vector<Vertex>                vertices;
    vertices.reserve(width * height);

    auto data = pointCloudFrame->getData();

    if(pointCloudType == OB_FORMAT_POINT) {
        OBPoint *points = reinterpret_cast<OBPoint *>(const_cast<uint8_t *>(data));
        for(uint32_t y = 0; y < height; ++y) {
            for(uint32_t x = 0; x < width; ++x) {
                int         index = y * width + x;
                const auto &pt    = points[index];
                if(std::fabs(pt.z) >= minPointValue) {
                    vertices.emplace_back(Vertex(pt.x, pt.y, pt.z));
                    vertexIndexMap[PixelCoord(y, x)] = static_cast<unsigned int>(vertices.size()) - 1;
                }
            }
        }
    }
    else if(pointCloudType == OB_FORMAT_RGB_POINT) {
        OBColorPoint *points = reinterpret_cast<OBColorPoint *>(const_cast<uint8_t *>(data));
        for(uint32_t y = 0; y < height; ++y) {
            for(uint32_t x = 0; x < width; ++x) {
                int         index = y * width + x;
                const auto &pt    = points[index];
                if(std::fabs(pt.z) >= minPointValue) {
                    vertices.emplace_back(Vertex(pt.x, pt.y, pt.z));
                    vertices.back().color[0]         = static_cast<uint8_t>(pt.r);
                    vertices.back().color[1]         = static_cast<uint8_t>(pt.g);
                    vertices.back().color[2]         = static_cast<uint8_t>(pt.b);
                    vertexIndexMap[PixelCoord(y, x)] = static_cast<unsigned int>(vertices.size()) - 1;
                }
            }
        }
    }

    if(vertices.empty()) {
        LOG_WARN("vertices is zero");
        return false;
    }

    std::vector<Triangle> faces;
    if(useMesh) {
        faces.reserve(vertices.size() * 3);

        for(uint32_t y = 1; y < height; ++y) {
            for(uint32_t x = 0; x + 1 < width; ++x) {
                unsigned int idxA = 0, idxB = 0, idxC = 0, idxD = 0;
                bool         hasA = false, hasB = false, hasC = false, hasD = false;

                PixelCoord coordA(y, x);
                PixelCoord coordB(y - 1, x);
                PixelCoord coordC(y - 1, x + 1);
                PixelCoord coordD(y, x + 1);

                std::map<PixelCoord, unsigned int>::iterator it;

                if((it = vertexIndexMap.find(coordA)) != vertexIndexMap.end()) {
                    hasA = true;
                    idxA = it->second;
                }
                if((it = vertexIndexMap.find(coordB)) != vertexIndexMap.end()) {
                    hasB = true;
                    idxB = it->second;
                }
                if((it = vertexIndexMap.find(coordC)) != vertexIndexMap.end()) {
                    hasC = true;
                    idxC = it->second;
                }
                if((it = vertexIndexMap.find(coordD)) != vertexIndexMap.end()) {
                    hasD = true;
                    idxD = it->second;
                }

                // Triangle: A-B-C
                if(hasA && hasB && hasC) {
                    if(idxA != idxB && idxA != idxC && idxB != idxC) {
                        float dist1 = computePointsDistance(vertices[idxA], vertices[idxB]);
                        float dist2 = computePointsDistance(vertices[idxA], vertices[idxC]);
                        if(dist1 < meshThreshold && dist2 < meshThreshold) {
                            faces.emplace_back(Triangle(idxA, idxC, idxB));
                        }
                    }
                }

                // Triangle: A-C-D
                if(hasA && hasC && hasD) {
                    if(idxA != idxC && idxA != idxD && idxC != idxD) {
                        float dist1 = computePointsDistance(vertices[idxA], vertices[idxD]);
                        float dist2 = computePointsDistance(vertices[idxA], vertices[idxC]);
                        if(dist1 < meshThreshold && dist2 < meshThreshold) {
                            faces.emplace_back(Triangle(idxA, idxD, idxC));
                        }
                    }
                }
            }
        }

        if(faces.empty()) {
            LOG_WARN("faces is zero");
            return false;
        }

        meshData.faces.swap(faces);
    }

    meshData.vertices.swap(vertices);

    bool colorPointCloud = (pointCloudType == OB_FORMAT_RGB_POINT);
    savePointCloud(fileName, meshData, useMesh, colorPointCloud, saveBinary);
    return true;
}

}  // namespace libobsensor
